{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f310313d",
   "metadata": {},
   "outputs": [],
   "source": [
    "deployment = \"gpt4\"\n",
    "model = \"gpt-4\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29f45ffa",
   "metadata": {},
   "source": [
    "# 回顾"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dbc1a0e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains import LLMRequestsChain\n",
    "from langchain import PromptTemplate, OpenAI, LLMChain\n",
    "from langchain.chat_models import AzureChatOpenAI\n",
    "llm = AzureChatOpenAI(deployment_name = deployment, model_name=model, temperature=0, max_tokens=200)\n",
    "template = \"\"\"Between >>> and <<< are the raw search result text from web.\n",
    "Extract the answer to the question '{query}' or say \"not found\" if the information is not contained.\n",
    "Use the format\n",
    "Extracted:<answer or \"not found\">\n",
    ">>> {requests_result} <<<\n",
    "Extracted:\"\"\"\n",
    "\n",
    "PROMPT = PromptTemplate(\n",
    "  input_variables=[\"query\", \"requests_result\"],\n",
    "  template=template,\n",
    ")\n",
    "\n",
    "\n",
    "query_chain = LLMRequestsChain(llm_chain = LLMChain(llm=llm, prompt=PROMPT), output_key=\"query_info\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1f2c1138",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain import PromptTemplate, OpenAI, LLMChain\n",
    "from langchain.chat_models import AzureChatOpenAI\n",
    "from langchain.chains import SequentialChain\n",
    "\n",
    "translating_prompt_template = \"\"\"\n",
    "translate \"{query_info}\" into English:\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "translating_chain = LLMChain(\n",
    "    llm = llm,\n",
    "    prompt=PromptTemplate.from_template(translating_prompt_template),\n",
    "    output_key = \"translated\"\n",
    ")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5a427661",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new SequentialChain chain...\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'\"20°C with moderate rain, east wind at level 4, temperature ranging from 20~23°C\"'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "llm = AzureChatOpenAI(deployment_name = deployment, model_name=model, temperature=0, max_tokens=200)\n",
    "\n",
    "def overall(question):\n",
    "    inputs = {\n",
    "      \"query\": question,\n",
    "      \"url\": \"http://www.baidu.com/s?wd=\" + question.replace(\" \", \"+\")\n",
    "      # \"url\": \"https://cn.bing.com/search?q=\" + question.replace(\" \", \"+\")\n",
    "    }\n",
    "    \n",
    "    overall_chain = SequentialChain(\n",
    "        chains=[query_chain, translating_chain],\n",
    "        input_variables=[\"query\",\"url\"],\n",
    "        output_variables=[\"translated\"],\n",
    "        verbose=True\n",
    "    )\n",
    "    \n",
    "    return overall_chain(inputs)[\"translated\"]\n",
    "\n",
    "overall(\"北京今天天气\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54124440",
   "metadata": {},
   "source": [
    "# 没有记忆的模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6df8a7d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "def get_response(input):\n",
    "    response = openai.ChatCompletion.create(\n",
    "        engine=deployment, # engine = \"deployment_name\".\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": \"You are an AI assistant.\"},   \n",
    "            {\"role\": \"user\", \"content\": input}\n",
    "        ],\n",
    "        temperature = 0.9, \n",
    "        max_tokens = 200\n",
    "      )\n",
    "    return response.choices[0].message.content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0804701b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "作为一个AI助手，我并不会有特定的日程安排。我24小时全天候在线，随时都可以为你提供服务。无论你需要记事、设定闹钟、查找信息或其他任何帮助，我都将尽力为你解决。祝你今天的健身训练顺利！\n",
      "对不起，我无法回答这个问题。作为一个AI，我没有实时访问或跟踪您的个人日程或习惯。我的设计是为了保护您的隐私和数据安全。\n"
     ]
    }
   ],
   "source": [
    "print(get_response(\"你好，今天是周二我要去健身，我一般每周二健身。你今天干什么？\"))\n",
    "print(get_response(\"我一般周几健身？\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47d9a1d0",
   "metadata": {},
   "source": [
    "# 通过Gradio快速展示功能"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "af13c666",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: gradio in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (3.42.0)\n",
      "Requirement already satisfied: aiofiles<24.0,>=22.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (22.1.0)\n",
      "Requirement already satisfied: altair<6.0,>=4.2.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (5.1.1)\n",
      "Requirement already satisfied: fastapi in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (0.99.1)\n",
      "Requirement already satisfied: ffmpy in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (0.3.1)\n",
      "Requirement already satisfied: gradio-client==0.5.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (0.5.0)\n",
      "Requirement already satisfied: httpx in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (0.24.1)\n",
      "Requirement already satisfied: huggingface-hub>=0.14.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (0.15.1)\n",
      "Requirement already satisfied: importlib-resources<7.0,>=1.3 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (6.0.1)\n",
      "Requirement already satisfied: jinja2<4.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (3.1.2)\n",
      "Requirement already satisfied: markupsafe~=2.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (2.1.1)\n",
      "Requirement already satisfied: matplotlib~=3.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (3.7.1)\n",
      "Requirement already satisfied: numpy~=1.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (1.24.3)\n",
      "Requirement already satisfied: orjson~=3.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (3.9.5)\n",
      "Requirement already satisfied: packaging in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (23.0)\n",
      "Requirement already satisfied: pandas<3.0,>=1.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (1.5.3)\n",
      "Requirement already satisfied: pillow<11.0,>=8.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (9.4.0)\n",
      "Requirement already satisfied: pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,<3.0.0,>=1.7.4 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (1.10.12)\n",
      "Requirement already satisfied: pydub in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (0.25.1)\n",
      "Requirement already satisfied: python-multipart in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (0.0.6)\n",
      "Requirement already satisfied: pyyaml<7.0,>=5.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (6.0)\n",
      "Requirement already satisfied: requests~=2.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (2.31.0)\n",
      "Requirement already satisfied: semantic-version~=2.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (2.10.0)\n",
      "Requirement already satisfied: typing-extensions~=4.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (4.7.1)\n",
      "Requirement already satisfied: uvicorn>=0.14.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (0.23.2)\n",
      "Requirement already satisfied: websockets<12.0,>=10.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio) (11.0.3)\n",
      "Requirement already satisfied: fsspec in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from gradio-client==0.5.0->gradio) (2023.4.0)\n",
      "Requirement already satisfied: jsonschema>=3.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from altair<6.0,>=4.2.0->gradio) (4.17.3)\n",
      "Requirement already satisfied: toolz in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from altair<6.0,>=4.2.0->gradio) (0.12.0)\n",
      "Requirement already satisfied: filelock in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from huggingface-hub>=0.14.0->gradio) (3.9.0)\n",
      "Requirement already satisfied: tqdm>=4.42.1 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from huggingface-hub>=0.14.0->gradio) (4.65.0)\n",
      "Requirement already satisfied: contourpy>=1.0.1 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from matplotlib~=3.0->gradio) (1.0.5)\n",
      "Requirement already satisfied: cycler>=0.10 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from matplotlib~=3.0->gradio) (0.11.0)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from matplotlib~=3.0->gradio) (4.25.0)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from matplotlib~=3.0->gradio) (1.4.4)\n",
      "Requirement already satisfied: pyparsing>=2.3.1 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from matplotlib~=3.0->gradio) (3.0.9)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from matplotlib~=3.0->gradio) (2.8.2)\n",
      "Requirement already satisfied: pytz>=2020.1 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from pandas<3.0,>=1.0->gradio) (2022.7)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from requests~=2.0->gradio) (2.0.4)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from requests~=2.0->gradio) (3.4)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from requests~=2.0->gradio) (1.26.16)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from requests~=2.0->gradio) (2023.7.22)\n",
      "Requirement already satisfied: click>=7.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from uvicorn>=0.14.0->gradio) (8.0.4)\n",
      "Requirement already satisfied: h11>=0.8 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from uvicorn>=0.14.0->gradio) (0.14.0)\n",
      "Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from fastapi->gradio) (0.27.0)\n",
      "Requirement already satisfied: httpcore<0.18.0,>=0.15.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from httpx->gradio) (0.17.3)\n",
      "Requirement already satisfied: sniffio in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from httpx->gradio) (1.2.0)\n",
      "Requirement already satisfied: anyio<5.0,>=3.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from httpcore<0.18.0,>=0.15.0->httpx->gradio) (3.5.0)\n",
      "Requirement already satisfied: attrs>=17.4.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (22.1.0)\n",
      "Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.18.0)\n",
      "Requirement already satisfied: six>=1.5 in /Users/mbj0593/anaconda3/lib/python3.11/site-packages (from python-dateutil>=2.7->matplotlib~=3.0->gradio) (1.16.0)\n"
     ]
    }
   ],
   "source": [
    "!pip install gradio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "13092b1c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL:  http://127.0.0.1:7860\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gradio as gr\n",
    "def respond(message, chat_history):\n",
    "        bot_message = get_response(message)\n",
    "        chat_history.append((message, bot_message)) #保存历史对话记录，用于显示\n",
    "        return \"\", chat_history\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    chatbot = gr.Chatbot(height=240) #对话框\n",
    "    msg = gr.Textbox(label=\"Prompt\") #输入框\n",
    "    btn = gr.Button(\"Submit\") #提交按钮\n",
    "    #提交\n",
    "    btn.click(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])\n",
    "    msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot]) \n",
    "gr.close_all()\n",
    "demo.launch()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "040d6f45",
   "metadata": {},
   "source": [
    "# 提供外部记忆"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d525ac99",
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "def get_response(msg):\n",
    "    # print(msg)\n",
    "    response = openai.ChatCompletion.create(\n",
    "        engine=deployment, # engine = \"deployment_name\".\n",
    "        messages=msg,\n",
    "        temperature = 0.9, \n",
    "        max_tokens = 600\n",
    "    )\n",
    "    return response.choices[0].message.content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "74ee0211",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def history_to_prompt(chat_history): # 将对话内容保存在一个List里\n",
    "    msg = [{\"role\": \"system\", \"content\": \"You are an AI assistant.\"}]\n",
    "    i = 0\n",
    "    for round_trip in chat_history: # 将List里的内容，组成 ChatCompletion的 messages部分，{role，content} dict\n",
    "        msg.append({\"role\": \"user\", \"content\": round_trip[0]})\n",
    "        msg.append({\"role\": \"assistant\", \"content\": round_trip[1]})\n",
    "    return msg\n",
    "\n",
    "def respond(message, chat_history):\n",
    "    his_msg = history_to_prompt(chat_history) #并装历史会话，ChatCompletion的 messages部分格式\n",
    "    his_msg.append({\"role\": \"user\", \"content\": message}) # 放入当前用户问题\n",
    "    bot_message = get_response(his_msg)\n",
    "    chat_history.append((message, bot_message)) # 将用户问题和返回保存到 历史记录 List\n",
    "    return \"\", chat_history\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "492d60dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL:  http://127.0.0.1:7863\n",
      "Running on public URL: https://53f21db72a477bb701.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://53f21db72a477bb701.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gradio as gr\n",
    "with gr.Blocks() as demo:\n",
    "    chatbot = gr.Chatbot(height=480) #just to fit the notebook\n",
    "    msg = gr.Textbox(label=\"Prompt\")\n",
    "    btn = gr.Button(\"Submit\")\n",
    "    clear = gr.ClearButton(components=[msg, chatbot], value=\"Clear console\")\n",
    "\n",
    "    btn.click(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])\n",
    "    msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot]) #Press enter to submit\n",
    "gr.close_all()\n",
    "demo.launch(share=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d03d5452",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
